{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Copyright (c) Microsoft Corporation.\n",
    "Licensed under the MIT License."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Abstractive Summarization using BertSumAbs on the CNN/DailyMail Dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Summary"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This notebook demonstrates how to fine-tune BERT for abstractive text summarization. Utility functions and classes in the NLP Best Practices repo are used for data preprocessing, model training, model scoring, result postprocessing, and model evaluation.\n",
    "\n",
    "### Abstractive Summarization\n",
    "Abstractive summarization is the task of taking an input text and summarizing its content in a shorter output text. Unlike extractive summarization, which copies sentences directly from the input text, abstractive summarization generates new text that rephrases the content of the input.\n",
    "\n",
    "### BertSumAbs\n",
    "\n",
    "BertSumAbs refers to a BERT-based abstractive summarization algorithm proposed in [Text Summarization with Pretrained Encoders](https://arxiv.org/abs/1908.08345), with [published examples](https://github.com/nlpyang/PreSumm). It uses the pretrained BERT model as the encoder and fine-tunes both the encoder and the decoder on a labeled summarization dataset such as the [CNN/DM dataset](https://github.com/harvardnlp/sent-summary).\n",
    "\n",
    "The figure at the end of this section compares the architecture of the original BERT model (left) with that of BERTSUM (right), on which BertSumAbs is built. In BERTSUM, an input document is split into sentences, and a [CLS] token is inserted before each sentence and a [SEP] token after it. Each token of the resulting sequence is then represented as the sum of three kinds of embeddings (token, interval segment, and position embeddings) before being fed into the transformer layers. BertSumAbs extends the position embeddings so that the input can be longer than 512 tokens, the maximum input length of the original BERT model.\n",
    "\n",
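    "The input layout described above can be sketched in plain Python. This is a minimal, hypothetical illustration (`bertsum_layout` is not part of utils_nlp) that assumes whitespace tokenization instead of BERT's WordPiece tokenizer; it only builds the token sequence and the ids that would index the three embedding tables.\n",
    "\n",
    "```python\n",
    "from typing import List, Tuple\n",
    "\n",
    "\n",
    "def bertsum_layout(sentences: List[str]) -> Tuple[List[str], List[int], List[int]]:\n",
    "    '''Wrap each sentence in [CLS] ... [SEP] and assign interval segment ids.\n",
    "\n",
    "    Hypothetical helper: whitespace splitting stands in for WordPiece.\n",
    "    '''\n",
    "    tokens: List[str] = []\n",
    "    segment_ids: List[int] = []\n",
    "    for i, sentence in enumerate(sentences):\n",
    "        sentence_tokens = ['[CLS]'] + sentence.lower().split() + ['[SEP]']\n",
    "        tokens.extend(sentence_tokens)\n",
    "        # interval segment ids alternate 0, 1, 0, ... from sentence to sentence\n",
    "        segment_ids.extend([i % 2] * len(sentence_tokens))\n",
    "    position_ids = list(range(len(tokens)))  # one position per token\n",
    "    return tokens, segment_ids, position_ids\n",
    "\n",
    "\n",
    "tokens, segment_ids, position_ids = bertsum_layout(['The cat sat.', 'It was happy.'])\n",
    "print(tokens)\n",
    "# ['[CLS]', 'the', 'cat', 'sat.', '[SEP]', '[CLS]', 'it', 'was', 'happy.', '[SEP]']\n",
    "print(segment_ids)  # [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]\n",
    "```\n",
    "\n",
    "Each token's input vector would then be the sum of the embedding rows selected by its entries in `tokens`, `segment_ids`, and `position_ids`.\n",
    "\n",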
    "Note that the figure shows only the encoder. The decoder of BertSumAbs is a multi-layer transformer whose weights are randomly initialized. Because the encoder is pretrained while the decoder is not, there is a mismatch between them that can make fine-tuning unstable. BertSumAbs therefore uses separate optimizers for the encoder and the decoder, each with its own learning-rate schedule. During generation, techniques such as beam search and trigram blocking can be used to improve the quality of the generated summaries.\n",
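    "\n",
    "The separate learning-rate schedules can be sketched as follows. This assumes the warmup-then-decay (Noam-style) formula described in the BertSumAbs paper, with base rates and warmup steps copied from this notebook's parameter cell; `noam_lr` is a hypothetical name, and the scheduler inside utils_nlp's `BertSumAbs` may differ in its details.\n",
    "\n",
    "```python\n",
    "def noam_lr(step: int, base_lr: float, warmup: int) -> float:\n",
    "    '''lr = base_lr * min(step ** -0.5, step * warmup ** -1.5)'''\n",
    "    step = max(step, 1)  # the formula is undefined at step 0\n",
    "    return base_lr * min(step ** -0.5, step * warmup ** -1.5)\n",
    "\n",
    "\n",
    "# Assumed to mirror LEARNING_RATE_* and WARMUP_STEPS_* in the parameter cell.\n",
    "encoder = {'base_lr': 5e-4 / 2.0, 'warmup': 2000}\n",
    "decoder = {'base_lr': 0.05 / 2.0, 'warmup': 1000}\n",
    "\n",
    "for step in (100, 1000, 2000, 5000):\n",
    "    # the decoder peaks earlier and at a much higher rate than the pretrained encoder\n",
    "    print(step, noam_lr(step, **encoder), noam_lr(step, **decoder))\n",
    "```\n",
    "\n",
    "Both rates rise linearly during warmup and peak at `base_lr / sqrt(warmup)`; the randomly initialized decoder gets a much larger rate than the pretrained encoder.\n",
    "\n",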
    "<img src=\"https://nlpbp.blob.core.windows.net/images/BertForSummarization.PNG\">\n"
   ]
  },
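  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Trigram blocking, mentioned above, discards any beam-search hypothesis that would repeat a trigram it already contains, which reduces repeated phrases in generated summaries. The following is a minimal sketch on a hypothetical toy model (`TABLE` and `toy_model` are a fixed bigram table, not the BertSumAbs decoder) without any length penalty; it shows the mechanism, not the implementation used in utils_nlp.\n",
    "\n",
    "```python\n",
    "import math\n",
    "from typing import Callable, Dict, List, Tuple\n",
    "\n",
    "# Hypothetical next-token probabilities keyed by the previous token.\n",
    "TABLE = {\n",
    "    '<s>': {'the': 1.0},\n",
    "    'the': {'cat': 1.0},\n",
    "    'cat': {'sat': 0.9, 'ran': 0.1},\n",
    "    'sat': {'the': 0.9, '</s>': 0.1},\n",
    "    'ran': {'</s>': 1.0},\n",
    "}\n",
    "\n",
    "\n",
    "def toy_model(prefix: Tuple[str, ...]) -> Dict[str, float]:\n",
    "    last = prefix[-1] if prefix else '<s>'\n",
    "    return {tok: math.log(p) for tok, p in TABLE[last].items()}\n",
    "\n",
    "\n",
    "def has_repeated_trigram(seq: Tuple[str, ...]) -> bool:\n",
    "    '''True if the last trigram of seq already occurs earlier in seq.'''\n",
    "    if len(seq) < 4:\n",
    "        return False\n",
    "    earlier = {seq[i:i + 3] for i in range(len(seq) - 3)}\n",
    "    return seq[-3:] in earlier\n",
    "\n",
    "\n",
    "def beam_search(model: Callable[[Tuple[str, ...]], Dict[str, float]],\n",
    "                beam_size: int, max_len: int, block_trigrams: bool = True) -> List[str]:\n",
    "    beams = [((), 0.0)]  # (token prefix, cumulative log-probability)\n",
    "    finished = []\n",
    "    for _ in range(max_len):\n",
    "        candidates = []\n",
    "        for prefix, score in beams:\n",
    "            for tok, logp in model(prefix).items():\n",
    "                seq = prefix + (tok,)\n",
    "                if block_trigrams and has_repeated_trigram(seq):\n",
    "                    continue  # trigram blocking: prune this hypothesis\n",
    "                candidates.append((seq, score + logp))\n",
    "        candidates.sort(key=lambda c: c[1], reverse=True)\n",
    "        beams = []\n",
    "        for seq, score in candidates[:beam_size]:\n",
    "            if seq[-1] == '</s>':\n",
    "                finished.append((seq, score))\n",
    "            else:\n",
    "                beams.append((seq, score))\n",
    "        if not beams:\n",
    "            break\n",
    "    finished.extend(beams)  # keep unfinished hypotheses if max_len is reached\n",
    "    if not finished:\n",
    "        return []\n",
    "    best, _ = max(finished, key=lambda c: c[1])\n",
    "    return [tok for tok in best if tok != '</s>']\n",
    "\n",
    "\n",
    "print(beam_search(toy_model, beam_size=2, max_len=8, block_trigrams=False))\n",
    "# ['the', 'cat', 'sat', 'the', 'cat', 'sat', 'the', 'cat']\n",
    "print(beam_search(toy_model, beam_size=2, max_len=8))\n",
    "# ['the', 'cat', 'ran']\n",
    "```\n",
    "\n",
    "Without blocking, the highest-scoring hypothesis loops on `the cat sat`; with blocking, the hypothesis that would repeat that trigram is pruned and the search returns `the cat ran`."
   ]
  },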
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Before you start\n",
    "\n",
    "Running this notebook on a GPU machine is recommended because it is very computationally intensive. Set QUICK_RUN = True to run the notebook on a small subset of the data with a smaller number of training steps. With QUICK_RUN = False, the notebook takes about 5 hours to run on a VM with 4 16GB NVIDIA V100 GPUs: fine-tuning takes around 1.5 hours and inference around 3.5 hours. Better performance can be achieved by increasing MAX_STEPS.\n",
    "\n",
    "* **ROUGE Evaluation**: To run ROUGE evaluation, refer to the compute_rouge_perl section of [summarization_evaluation.ipynb](./summarization_evaluation.ipynb) for setup.\n",
    "\n",
    "* **Distributed Training**:\n",
    "In a Jupyter notebook, only PyTorch [DataParallel](https://pytorch.org/docs/master/nn.html#dataparallel) can be used. Faster training and larger batch sizes can be achieved with PyTorch [DistributedDataParallel](https://pytorch.org/docs/master/notes/ddp.html) (DDP). The script [abstractive_summarization_bertsum_cnndm_distributed_train.py](./abstractive_summarization_bertsum_cnndm_distributed_train.py) shows an example of how to use DDP.\n",
    "\n",
    "* **Mixed Precision Training**:\n",
    "By default, this notebook doesn't use mixed precision training. Setting FP16 to True can give faster training and allow larger batch sizes. Refer to https://nvidia.github.io/apex and https://github.com/nvidia/apex for details on mixed precision training, and check whether the GPU model on your machine supports it. Mixed precision inference is also enabled in the prediction utility function. When you use mixed precision training and/or inference, model performance can be slightly worse than in full precision mode."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "QUICK_RUN = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import shutil\n",
    "import sys\n",
    "from tempfile import TemporaryDirectory\n",
    "import torch\n",
    "\n",
    "nlp_path = os.path.abspath(\"../../\")\n",
    "if nlp_path not in sys.path:\n",
    "    sys.path.insert(0, nlp_path)\n",
    "\n",
    "from utils_nlp.models.transformers.abstractive_summarization_bertsum import (\n",
    "    BertSumAbs,\n",
    "    BertSumAbsProcessor,\n",
    ")\n",
    "\n",
    "from utils_nlp.dataset.cnndm import CNNDMSummarizationDataset\n",
    "from utils_nlp.eval import compute_rouge_python\n",
    "\n",
    "from utils_nlp.models.transformers.datasets import SummarizationDataset\n",
    "import nltk\n",
    "from nltk import tokenize\n",
    "\n",
    "import pandas as pd\n",
    "import pprint\n",
    "import scrapbook as sb"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Data Preprocessing"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The dataset used in this notebook is the CNN/DM dataset, which contains news articles from CNN and the Daily Mail along with accompanying questions. The highlights of each article are used as its summary. The dataset consists of ~287K training examples, ~13K validation examples, and ~11K test examples. The news articles are 781 tokens long on average, and the summaries have 3.75 sentences and 56 tokens on average.\n",
    "\n",
    "Data preprocessing mainly consists of splitting each input document into sentences."
   ]
  },
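  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To illustrate the shape of the preprocessed data, the sketch below turns a document into a list of sentences. It is a simplified, hypothetical stand-in for `nltk.tokenize.sent_tokenize`, which this notebook uses for single-input prediction; a regex like this one breaks on abbreviations such as `Mr.`, which is why a trained sentence tokenizer is preferable in practice.\n",
    "\n",
    "```python\n",
    "import re\n",
    "from typing import List\n",
    "\n",
    "\n",
    "def naive_sent_split(document: str) -> List[str]:\n",
    "    '''Split after ., ! or ? followed by whitespace (hypothetical helper).'''\n",
    "    parts = re.split(r'(?<=[.!?])\\s+', document.strip())\n",
    "    return [p for p in parts if p]  # drop empty strings, e.g. for an empty document\n",
    "\n",
    "\n",
    "doc = 'CNN reports a storm. Roads are closed! Will schools open?'\n",
    "print(naive_sent_split(doc))\n",
    "# ['CNN reports a storm.', 'Roads are closed!', 'Will schools open?']\n",
    "```"
   ]
  },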
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# the data path used to save the downloaded data file\n",
    "DATA_PATH = TemporaryDirectory().name\n",
    "# Number of lines from the head of each data file to use for preprocessing; -1 means all lines.\n",
    "TOP_N = 100\n",
    "if not QUICK_RUN:\n",
    "    TOP_N = -1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_dataset, test_dataset = CNNDMSummarizationDataset(\n",
    "    top_n=TOP_N, local_cache_path=DATA_PATH, prepare_extractive=False\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "len(train_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "len(test_dataset)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model Finetuning"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": [
     "parameters"
    ]
   },
   "outputs": [],
   "source": [
    "# notebook parameters\n",
    "# the cache path\n",
    "CACHE_PATH = TemporaryDirectory().name\n",
    "\n",
    "# model parameters\n",
    "MODEL_NAME = \"bert-base-uncased\"\n",
    "MAX_POS = 768\n",
    "MAX_SOURCE_SEQ_LENGTH = 640\n",
    "MAX_TARGET_SEQ_LENGTH = 140\n",
    "\n",
    "# mixed precision setting. To enable mixed precision training, follow instructions in SETUP.md.\n",
    "FP16 = False\n",
    "if FP16:\n",
    "    FP16_OPT_LEVEL = \"O2\"\n",
    "\n",
    "# fine-tuning parameters\n",
    "# batch size per GPU, in number of examples\n",
    "BATCH_SIZE_PER_GPU = 1\n",
    "\n",
    "\n",
    "# GPU used for training\n",
    "NUM_GPUS = torch.cuda.device_count()\n",
    "if NUM_GPUS > 0:\n",
    "    BATCH_SIZE = NUM_GPUS * BATCH_SIZE_PER_GPU\n",
    "else:\n",
    "    BATCH_SIZE = 1\n",
    "\n",
    "\n",
    "# Learning rate\n",
    "LEARNING_RATE_BERT = 5e-4 / 2.0\n",
    "LEARNING_RATE_DEC = 0.05 / 2.0\n",
    "\n",
    "\n",
    "# How often (in steps) training statistics are reported and checkpoints are saved\n",
    "REPORT_EVERY = 10\n",
    "SAVE_EVERY = 500\n",
    "\n",
    "# total number of steps for training\n",
    "MAX_STEPS = 1000\n",
    "\n",
    "if not QUICK_RUN:\n",
    "    MAX_STEPS = 5000\n",
    "\n",
    "WARMUP_STEPS_BERT = 2000\n",
    "WARMUP_STEPS_DEC = 1000"
   ]
  },
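  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough check on why increasing MAX_STEPS helps, the sketch below estimates how much of the training set the full run sees. It assumes no gradient accumulation, the 4-GPU setup described above with BATCH_SIZE_PER_GPU = 1, and roughly 287K training examples; `epochs_covered` is a hypothetical helper.\n",
    "\n",
    "```python\n",
    "def epochs_covered(max_steps: int, batch_size: int, num_train_examples: int) -> float:\n",
    "    '''Fraction of an epoch seen in max_steps optimizer steps (no gradient accumulation).'''\n",
    "    return max_steps * batch_size / num_train_examples\n",
    "\n",
    "\n",
    "# Assumed full run: MAX_STEPS = 5000, BATCH_SIZE = 4 GPUs * 1 example per GPU.\n",
    "print(round(epochs_covered(5000, 4, 287_000), 3))  # 0.07\n",
    "```\n",
    "\n",
    "Under these assumptions the full run sees only about 7% of the training examples, so a longer schedule is likely to improve the results."
   ]
  },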
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# processor that contains the collate function used to load the preprocessed data\n",
    "processor = BertSumAbsProcessor(\n",
    "    cache_dir=CACHE_PATH,\n",
    "    max_src_len=MAX_SOURCE_SEQ_LENGTH,\n",
    "    max_tgt_len=MAX_TARGET_SEQ_LENGTH,\n",
    ")\n",
    "# summarizer\n",
    "summarizer = BertSumAbs(\n",
    "    processor, cache_dir=CACHE_PATH, max_pos_length=MAX_POS\n",
    ")"
   ]
  },
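  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`MAX_POS = 768` is larger than the 512 positions BERT was pretrained with, so the encoder needs extra position embeddings. The NumPy sketch below shows one simple way to build them: keep the 512 pretrained rows and fill the new rows with a copy of the last one. This is an assumption for illustration (random initialization of the extra rows is another option), not necessarily what `BertSumAbs` does internally, and the random matrix stands in for real pretrained weights.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "BERT_MAX_POS = 512  # positions in the pretrained BERT model\n",
    "HIDDEN_SIZE = 768   # hidden size of bert-base-uncased\n",
    "MAX_POS = 768       # same value as the notebook parameter\n",
    "\n",
    "\n",
    "def extend_position_embeddings(pretrained: np.ndarray, max_pos: int) -> np.ndarray:\n",
    "    n_pos, hidden = pretrained.shape\n",
    "    if max_pos <= n_pos:\n",
    "        return pretrained[:max_pos].copy()\n",
    "    extended = np.empty((max_pos, hidden), dtype=pretrained.dtype)\n",
    "    extended[:n_pos] = pretrained      # reuse the pretrained positions\n",
    "    extended[n_pos:] = pretrained[-1]  # broadcast the last row to every new position\n",
    "    return extended\n",
    "\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "pretrained = rng.normal(size=(BERT_MAX_POS, HIDDEN_SIZE)).astype(np.float32)\n",
    "extended = extend_position_embeddings(pretrained, MAX_POS)\n",
    "print(extended.shape)  # (768, 768)\n",
    "```"
   ]
  },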
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "BATCH_SIZE  # effective batch size across all GPUs (1 when running on CPU)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "summarizer.fit(\n",
    "    train_dataset,\n",
    "    num_gpus=NUM_GPUS,\n",
    "    batch_size=BATCH_SIZE,\n",
    "    max_steps=MAX_STEPS,\n",
    "    learning_rate_bert=LEARNING_RATE_BERT,\n",
    "    learning_rate_dec=LEARNING_RATE_DEC,\n",
    "    warmup_steps_bert=WARMUP_STEPS_BERT,\n",
    "    warmup_steps_dec=WARMUP_STEPS_DEC,\n",
    "    save_every=SAVE_EVERY,\n",
    "    report_every=REPORT_EVERY * 5,\n",
    "    fp16=FP16,\n",
    "    # checkpoint=\"saved checkpoint path\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "summarizer.save_model(MAX_STEPS, os.path.join(CACHE_PATH, \"bertsumabs.pt\"))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model Evaluation\n",
    "\n",
    "This notebook computes ROUGE scores with `compute_rouge_python`. To run ROUGE evaluation with the Perl script instead, refer to the compute_rouge_perl section of [summarization_evaluation.ipynb](summarization_evaluation.ipynb) for setup.\n",
    "With QUICK_RUN = False and the settings in this notebook, you should get ROUGE scores close to the following:\n",
    "```\n",
    "{'rouge-1': {'f': 0.34819639878321873,\n",
    "             'p': 0.39977932634737307,\n",
    "             'r': 0.34429079596863604},\n",
    " 'rouge-2': {'f': 0.13919271352557894,\n",
    "             'p': 0.16129965067780644,\n",
    "             'r': 0.1372938054050938},\n",
    " 'rouge-l': {'f': 0.2313282318854973,\n",
    "             'p': 0.26664667422849747,\n",
    "             'r': 0.22850294283399628}}\n",
    "```\n",
    "\n",
    "Better performance can be achieved by increasing MAX_STEPS."
   ]
  },
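  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For intuition about the scores above, ROUGE-N measures n-gram overlap between a generated summary and its reference: precision is the overlap divided by the candidate's n-gram count, recall divides by the reference's count, and F is their harmonic mean. The sketch below is a simplified, hypothetical implementation on lowercased whitespace tokens; `compute_rouge_python` may tokenize, stem, and aggregate scores differently, so its results will not match this sketch exactly.\n",
    "\n",
    "```python\n",
    "from collections import Counter\n",
    "from typing import Dict, List\n",
    "\n",
    "\n",
    "def ngrams(tokens: List[str], n: int) -> Counter:\n",
    "    return Counter(tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1))\n",
    "\n",
    "\n",
    "def rouge_n(candidate: str, reference: str, n: int) -> Dict[str, float]:\n",
    "    cand = ngrams(candidate.lower().split(), n)\n",
    "    ref = ngrams(reference.lower().split(), n)\n",
    "    overlap = sum((cand & ref).values())  # matches clipped to the smaller count\n",
    "    p = overlap / max(sum(cand.values()), 1)\n",
    "    r = overlap / max(sum(ref.values()), 1)\n",
    "    f = 0.0 if p + r == 0 else 2 * p * r / (p + r)\n",
    "    return {'f': f, 'p': p, 'r': r}\n",
    "\n",
    "\n",
    "print(rouge_n('the cat sat', 'the cat ran', 1))  # p = r = 2/3, f is about 0.667\n",
    "print(rouge_n('the cat sat', 'the cat ran', 2))  # p = r = f = 0.5\n",
    "```"
   ]
  },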
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Uncomment the lines below to load the saved model from a checkpoint, e.g. in a new session\n",
    "# checkpoint = torch.load(os.path.join(CACHE_PATH, \"bertsumabs.pt\"), map_location=\"cpu\")\n",
    "# summarizer = BertSumAbs(\n",
    "#     processor, cache_dir=CACHE_PATH, max_pos_length=MAX_POS, test=True\n",
    "# )\n",
    "# summarizer.model.load_checkpoint(checkpoint['model'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "TEST_TOP_N = 32\n",
    "if not QUICK_RUN:\n",
    "    TEST_TOP_N = len(test_dataset)\n",
    "\n",
    "if NUM_GPUS:\n",
    "    BATCH_SIZE = NUM_GPUS * BATCH_SIZE_PER_GPU\n",
    "else:\n",
    "    BATCH_SIZE = 1\n",
    "    \n",
    "shortened_dataset = test_dataset.shorten(top_n=TEST_TOP_N)\n",
    "src = shortened_dataset.get_source()\n",
    "reference_summaries = [\" \".join(t).rstrip(\"\\n\") for t in shortened_dataset.get_target()]\n",
    "generated_summaries = summarizer.predict(\n",
    "    shortened_dataset, batch_size=BATCH_SIZE, num_gpus=NUM_GPUS\n",
    ")\n",
    "assert len(generated_summaries) == len(reference_summaries)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "src[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "generated_summaries[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "reference_summaries[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "rouge_scores = compute_rouge_python(cand=generated_summaries, ref=reference_summaries)\n",
    "pprint.pprint(rouge_scores)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# for testing\n",
    "sb.glue(\"rouge_2_f_score\", rouge_scores['rouge-2']['f'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Prediction on a single input sample"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "source = \"\"\"\n",
    "But under the new rule, set to be announced in the next 48 hours, Border Patrol agents would immediately return anyone to Mexico — without any detainment and without any due process — who attempts to cross the southwestern border between the legal ports of entry. The person would not be held for any length of time in an American facility.\n",
    "\n",
    "Although they advised that details could change before the announcement, administration officials said the measure was needed to avert what they fear could be a systemwide outbreak of the coronavirus inside detention facilities along the border. Such an outbreak could spread quickly through the immigrant population and could infect large numbers of Border Patrol agents, leaving the southwestern border defenses weakened, the officials argued.\n",
    "The Trump administration plans to immediately turn back all asylum seekers and other foreigners attempting to enter the United States from Mexico illegally, saying the nation cannot risk allowing the coronavirus to spread through detention facilities and Border Patrol agents, four administration officials said.\n",
    "The administration officials said the ports of entry would remain open to American citizens, green-card holders and foreigners with proper documentation. Some foreigners would be blocked, including Europeans currently subject to earlier travel restrictions imposed by the administration. The points of entry will also be open to commercial traffic.\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "test_dataset = SummarizationDataset(\n",
    "    None, source=[source], source_preprocessing=[tokenize.sent_tokenize],\n",
    ")\n",
    "generated_summaries = summarizer.predict(test_dataset, batch_size=1, num_gpus=NUM_GPUS)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "generated_summaries[0]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Clean up temporary folders"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if os.path.exists(DATA_PATH):\n",
    "    shutil.rmtree(DATA_PATH, ignore_errors=True)\n",
    "if os.path.exists(CACHE_PATH):\n",
    "    shutil.rmtree(CACHE_PATH, ignore_errors=True)"
   ]
  }
 ],
 "metadata": {
  "celltoolbar": "Tags",
  "kernelspec": {
   "display_name": "Python (nlp_gpu)",
   "language": "python",
   "name": "nlp_gpu"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}